home *** CD-ROM | disk | FTP | other *** search
/ Cream of the Crop 26 / Cream of the Crop 26.iso / os2 / octa209s.zip / octave-2.09 / liboctave / dDiagMatrix.cc < prev    next >
C/C++ Source or Header  |  1996-10-12  |  9KB  |  471 lines

  1. // DiagMatrix manipulations.
  2. /*
  3.  
  4. Copyright (C) 1996 John W. Eaton
  5.  
  6. This file is part of Octave.
  7.  
  8. Octave is free software; you can redistribute it and/or modify it
  9. under the terms of the GNU General Public License as published by the
  10. Free Software Foundation; either version 2, or (at your option) any
  11. later version.
  12.  
  13. Octave is distributed in the hope that it will be useful, but WITHOUT
  14. ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  15. FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  16. for more details.
  17.  
  18. You should have received a copy of the GNU General Public License
  19. along with Octave; see the file COPYING.  If not, write to the Free
  20. Software Foundation, 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
  21.  
  22. */
  23.  
  24. #if defined (__GNUG__)
  25. #pragma implementation
  26. #endif
  27.  
  28. #ifdef HAVE_CONFIG_H
  29. #include <config.h>
  30. #endif
  31.  
  32. #include <iostream.h>
  33.  
  34. #include "lo-error.h"
  35. #include "mx-base.h"
  36. #include "mx-inlines.cc"
  37. #include "oct-cmplx.h"
  38.  
  39. // Diagonal Matrix class.
  40.  
  41. bool
  42. DiagMatrix::operator == (const DiagMatrix& a) const
  43. {
  44.   if (rows () != a.rows () || cols () != a.cols ())
  45.     return 0;
  46.  
  47.   return equal (data (), a.data (), length ());
  48. }
  49.  
  50. bool
  51. DiagMatrix::operator != (const DiagMatrix& a) const
  52. {
  53.   return !(*this == a);
  54. }
  55.  
  56. DiagMatrix&
  57. DiagMatrix::fill (double val)
  58. {
  59.   for (int i = 0; i < length (); i++)
  60.     elem (i, i) = val;
  61.   return *this;
  62. }
  63.  
  64. DiagMatrix&
  65. DiagMatrix::fill (double val, int beg, int end)
  66. {
  67.   if (beg < 0 || end >= length () || end < beg)
  68.     {
  69.       (*current_liboctave_error_handler) ("range error for fill");
  70.       return *this;
  71.     }
  72.  
  73.   for (int i = beg; i <= end; i++)
  74.     elem (i, i) = val;
  75.  
  76.   return *this;
  77. }
  78.  
  79. DiagMatrix&
  80. DiagMatrix::fill (const ColumnVector& a)
  81. {
  82.   int len = length ();
  83.   if (a.length () != len)
  84.     {
  85.       (*current_liboctave_error_handler) ("range error for fill");
  86.       return *this;
  87.     }
  88.  
  89.   for (int i = 0; i < len; i++)
  90.     elem (i, i) = a.elem (i);
  91.  
  92.   return *this;
  93. }
  94.  
  95. DiagMatrix&
  96. DiagMatrix::fill (const RowVector& a)
  97. {
  98.   int len = length ();
  99.   if (a.length () != len)
  100.     {
  101.       (*current_liboctave_error_handler) ("range error for fill");
  102.       return *this;
  103.     }
  104.  
  105.   for (int i = 0; i < len; i++)
  106.     elem (i, i) = a.elem (i);
  107.  
  108.   return *this;
  109. }
  110.  
  111. DiagMatrix&
  112. DiagMatrix::fill (const ColumnVector& a, int beg)
  113. {
  114.   int a_len = a.length ();
  115.   if (beg < 0 || beg + a_len >= length ())
  116.     {
  117.       (*current_liboctave_error_handler) ("range error for fill");
  118.       return *this;
  119.     }
  120.  
  121.   for (int i = 0; i < a_len; i++)
  122.     elem (i+beg, i+beg) = a.elem (i);
  123.  
  124.   return *this;
  125. }
  126.  
  127. DiagMatrix&
  128. DiagMatrix::fill (const RowVector& a, int beg)
  129. {
  130.   int a_len = a.length ();
  131.   if (beg < 0 || beg + a_len >= length ())
  132.     {
  133.       (*current_liboctave_error_handler) ("range error for fill");
  134.       return *this;
  135.     }
  136.  
  137.   for (int i = 0; i < a_len; i++)
  138.     elem (i+beg, i+beg) = a.elem (i);
  139.  
  140.   return *this;
  141. }
  142.  
  143. DiagMatrix
  144. DiagMatrix::transpose (void) const
  145. {
  146.   return DiagMatrix (dup (data (), length ()), cols (), rows ());
  147. }
  148.  
  149. DiagMatrix
  150. real (const ComplexDiagMatrix& a)
  151. {
  152.   DiagMatrix retval;
  153.   int a_len = a.length ();
  154.   if (a_len > 0)
  155.     retval = DiagMatrix (real_dup (a.data (), a_len), a.rows (),
  156.              a.cols ());
  157.   return retval;
  158. }
  159.  
  160. DiagMatrix
  161. imag (const ComplexDiagMatrix& a)
  162. {
  163.   DiagMatrix retval;
  164.   int a_len = a.length ();
  165.   if (a_len > 0)
  166.     retval = DiagMatrix (imag_dup (a.data (), a_len), a.rows (),
  167.              a.cols ());
  168.   return retval;
  169. }
  170.  
  171. Matrix
  172. DiagMatrix::extract (int r1, int c1, int r2, int c2) const
  173. {
  174.   if (r1 > r2) { int tmp = r1; r1 = r2; r2 = tmp; }
  175.   if (c1 > c2) { int tmp = c1; c1 = c2; c2 = tmp; }
  176.  
  177.   int new_r = r2 - r1 + 1;
  178.   int new_c = c2 - c1 + 1;
  179.  
  180.   Matrix result (new_r, new_c);
  181.  
  182.   for (int j = 0; j < new_c; j++)
  183.     for (int i = 0; i < new_r; i++)
  184.       result.elem (i, j) = elem (r1+i, c1+j);
  185.  
  186.   return result;
  187. }
  188.  
  189. // extract row or column i.
  190.  
  191. RowVector
  192. DiagMatrix::row (int i) const
  193. {
  194.   int nr = rows ();
  195.   int nc = cols ();
  196.   if (i < 0 || i >= nr)
  197.     {
  198.       (*current_liboctave_error_handler) ("invalid row selection");
  199.       return RowVector (); 
  200.     }
  201.  
  202.   RowVector retval (nc, 0.0);
  203.   if (nr <= nc || (nr > nc && i < nc))
  204.     retval.elem (i) = elem (i, i);
  205.  
  206.   return retval;
  207. }
  208.  
  209. RowVector
  210. DiagMatrix::row (char *s) const
  211. {
  212.   if (! s)
  213.     {
  214.       (*current_liboctave_error_handler) ("invalid row selection");
  215.       return RowVector (); 
  216.     }
  217.  
  218.   char c = *s;
  219.   if (c == 'f' || c == 'F')
  220.     return row (0);
  221.   else if (c == 'l' || c == 'L')
  222.     return row (rows () - 1);
  223.   else
  224.     {
  225.       (*current_liboctave_error_handler) ("invalid row selection");
  226.       return RowVector (); 
  227.     }
  228. }
  229.  
  230. ColumnVector
  231. DiagMatrix::column (int i) const
  232. {
  233.   int nr = rows ();
  234.   int nc = cols ();
  235.   if (i < 0 || i >= nc)
  236.     {
  237.       (*current_liboctave_error_handler) ("invalid column selection");
  238.       return ColumnVector (); 
  239.     }
  240.  
  241.   ColumnVector retval (nr, 0.0);
  242.   if (nr >= nc || (nr < nc && i < nr))
  243.     retval.elem (i) = elem (i, i);
  244.  
  245.   return retval;
  246. }
  247.  
  248. ColumnVector
  249. DiagMatrix::column (char *s) const
  250. {
  251.   if (! s)
  252.     {
  253.       (*current_liboctave_error_handler) ("invalid column selection");
  254.       return ColumnVector (); 
  255.     }
  256.  
  257.   char c = *s;
  258.   if (c == 'f' || c == 'F')
  259.     return column (0);
  260.   else if (c == 'l' || c == 'L')
  261.     return column (cols () - 1);
  262.   else
  263.     {
  264.       (*current_liboctave_error_handler) ("invalid column selection");
  265.       return ColumnVector (); 
  266.     }
  267. }
  268.  
  269. DiagMatrix
  270. DiagMatrix::inverse (void) const
  271. {
  272.   int info;
  273.   return inverse (info);
  274. }
  275.  
  276. DiagMatrix
  277. DiagMatrix::inverse (int &info) const
  278. {
  279.   int nr = rows ();
  280.   int nc = cols ();
  281.   int len = length ();
  282.   if (nr != nc)
  283.     {
  284.       (*current_liboctave_error_handler) ("inverse requires square matrix");
  285.       return DiagMatrix ();
  286.     }
  287.  
  288.   DiagMatrix retval (nr, nc);
  289.  
  290.   info = 0;
  291.   for (int i = 0; i < len; i++)
  292.     {
  293.       if (elem (i, i) == 0.0)
  294.     {
  295.       info = -1;
  296.       return *this;
  297.     }
  298.       else
  299.     retval.elem (i, i) = 1.0 / elem (i, i);
  300.     }
  301.  
  302.   return retval;
  303. }
  304.  
  305. // diagonal matrix by diagonal matrix -> diagonal matrix operations
  306.  
  307. DiagMatrix&
  308. DiagMatrix::operator += (const DiagMatrix& a)
  309. {
  310.   int nr = rows ();
  311.   int nc = cols ();
  312.  
  313.   int a_nr = a.rows ();
  314.   int a_nc = a.cols ();
  315.  
  316.   if (nr != a_nr || nc != a_nc)
  317.     {
  318.       gripe_nonconformant ("operator +=", nr, nc, a_nr, a_nc);
  319.       return *this;
  320.     }
  321.  
  322.   if (nc == 0 || nr == 0)
  323.     return *this;
  324.  
  325.   double *d = fortran_vec (); // Ensures only one reference to my privates!
  326.  
  327.   add2 (d, a.data (), length ());
  328.   return *this;
  329. }
  330.  
  331. DiagMatrix&
  332. DiagMatrix::operator -= (const DiagMatrix& a)
  333. {
  334.   int nr = rows ();
  335.   int nc = cols ();
  336.  
  337.   int a_nr = a.rows ();
  338.   int a_nc = a.cols ();
  339.  
  340.   if (nr != a_nr || nc != a_nc)
  341.     {
  342.       gripe_nonconformant ("operator -=", nr, nc, a_nr, a_nc);
  343.       return *this;
  344.     }
  345.  
  346.   if (nr == 0 || nc == 0)
  347.     return *this;
  348.  
  349.   double *d = fortran_vec (); // Ensures only one reference to my privates!
  350.  
  351.   subtract2 (d, a.data (), length ());
  352.   return *this;
  353. }
  354.  
  355. // diagonal matrix by diagonal matrix -> diagonal matrix operations
  356.  
  357. DiagMatrix
  358. operator * (const DiagMatrix& a, const DiagMatrix& b)
  359. {
  360.   int nr_a = a.rows ();
  361.   int nc_a = a.cols ();
  362.  
  363.   int nr_b = b.rows ();
  364.   int nc_b = b.cols ();
  365.  
  366.   if (nc_a != nr_b)
  367.     {
  368.       gripe_nonconformant ("operaotr *", nr_a, nc_a, nr_b, nc_b);
  369.       return DiagMatrix ();
  370.     }
  371.  
  372.   if (nr_a == 0 || nc_a == 0 || nc_b == 0)
  373.     return DiagMatrix (nr_a, nc_a, 0.0);
  374.  
  375.   DiagMatrix c (nr_a, nc_b);
  376.  
  377.   int len = nr_a < nc_b ? nr_a : nc_b;
  378.  
  379.   for (int i = 0; i < len; i++)
  380.     {
  381.       double a_element = a.elem (i, i);
  382.       double b_element = b.elem (i, i);
  383.  
  384.       if (a_element == 0.0 || b_element == 0.0)
  385.         c.elem (i, i) = 0.0;
  386.       else if (a_element == 1.0)
  387.         c.elem (i, i) = b_element;
  388.       else if (b_element == 1.0)
  389.         c.elem (i, i) = a_element;
  390.       else
  391.         c.elem (i, i) = a_element * b_element;
  392.     }
  393.  
  394.   return c;
  395. }
  396.  
  397. // other operations
  398.  
  399. ColumnVector
  400. DiagMatrix::diag (void) const
  401. {
  402.   return diag (0);
  403. }
  404.  
  405. // Could be optimized...
  406.  
  407. ColumnVector
  408. DiagMatrix::diag (int k) const
  409. {
  410.   int nnr = rows ();
  411.   int nnc = cols ();
  412.   if (k > 0)
  413.     nnc -= k;
  414.   else if (k < 0)
  415.     nnr += k;
  416.  
  417.   ColumnVector d;
  418.  
  419.   if (nnr > 0 && nnc > 0)
  420.     {
  421.       int ndiag = (nnr < nnc) ? nnr : nnc;
  422.  
  423.       d.resize (ndiag);
  424.  
  425.       if (k > 0)
  426.     {
  427.       for (int i = 0; i < ndiag; i++)
  428.         d.elem (i) = elem (i, i+k);
  429.     }
  430.       else if ( k < 0)
  431.     {
  432.       for (int i = 0; i < ndiag; i++)
  433.         d.elem (i) = elem (i-k, i);
  434.     }
  435.       else
  436.     {
  437.       for (int i = 0; i < ndiag; i++)
  438.         d.elem (i) = elem (i, i);
  439.     }
  440.     }
  441.   else
  442.     cerr << "diag: requested diagonal out of range\n";
  443.  
  444.   return d;
  445. }
  446.  
  447. ostream&
  448. operator << (ostream& os, const DiagMatrix& a)
  449. {
  450. //  int field_width = os.precision () + 7;
  451.  
  452.   for (int i = 0; i < a.rows (); i++)
  453.     {
  454.       for (int j = 0; j < a.cols (); j++)
  455.     {
  456.       if (i == j)
  457.         os << " " /* setw (field_width) */ << a.elem (i, i);
  458.       else
  459.         os << " " /* setw (field_width) */ << 0.0;
  460.     }
  461.       os << "\n";
  462.     }
  463.   return os;
  464. }
  465.  
  466. /*
  467. ;;; Local Variables: ***
  468. ;;; mode: C++ ***
  469. ;;; End: ***
  470. */
  471.